# coding=utf-8
# Copyright 2020 Google and DeepMind.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

try:
    import transformers
except ModuleNotFoundError:
    from pip._internal import main as pip
    pip(['install', '--user', 'transformers'])

import argparse
from transformers import BertTokenizer, XLMTokenizer, XLMRobertaTokenizer
import os
from collections import defaultdict
import csv
import random
import os
import shutil
import json


# Tokenizer class for each accepted --model_type value; used by the
# *_tokenize preprocessing tasks to count subwords per token.
TOKENIZERS = dict(
    bert=BertTokenizer,
    xlm=XLMTokenizer,
    xlmr=XLMRobertaTokenizer,
)

def panx_tokenize_preprocess(args):
    """Split PAN-X sentences so every chunk fits within ``args.max_len`` subwords.

    For each language in ``args.languages`` (comma separated) and each of the
    dev/test/train splits, read ``{args.data_dir}/{split}-{lang}.tsv``
    (``token label`` per line, blank line between sentences). Write two files
    under ``{args.output_dir}/{lang}/``:

    * ``{split}.{args.model_name_or_path}``: ``token<TAB>label`` lines. A
      blank line is inserted whenever adding the next token would push the
      running subword count past the budget.
    * ``{split}.{args.model_name_or_path}.idx``: one line per token, holding
      the index of the input sentence the token came from.

    Splits whose two output files both already exist are skipped.
    """
    def _preprocess_one_file(infile, outfile, idxfile, tokenizer, max_len):
        """Process one split; return 1 on success, 0 if ``infile`` is missing."""
        if not os.path.exists(infile):
            print(f'{infile} not exists')
            return 0
        # Leave room for the special tokens added around each sequence
        # (this script counts 3 for XLM-R and 2 for the other tokenizers).
        special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2
        max_seq_len = max_len - special_tokens_count
        subword_len_counter = idx = 0
        # Explicit utf-8: the data is multilingual and must not depend on the locale.
        with open(infile, "rt", encoding="utf-8") as fin, \
                open(outfile, "w", encoding="utf-8") as fout, \
                open(idxfile, "w", encoding="utf-8") as fidx:
            for line in fin:
                line = line.strip()
                if not line:
                    # Sentence boundary: copy it through and advance the sentence index.
                    fout.write('\n')
                    fidx.write('\n')
                    idx += 1
                    subword_len_counter = 0
                    continue

                items = line.split()
                token = items[0].strip()  # never empty: ``line`` is non-blank
                label = items[1].strip() if len(items) == 2 else 'O'
                current_subwords_len = len(tokenizer.tokenize(token))

                # Tokens the tokenizer drops, or that can never fit, become <unk>.
                if current_subwords_len == 0 or current_subwords_len > max_seq_len:
                    token = tokenizer.unk_token
                    current_subwords_len = 1

                if subword_len_counter + current_subwords_len > max_seq_len:
                    # Start a new chunk; the idx file still records the
                    # original sentence index for this token.
                    fout.write(f"\n{token}\t{label}\n")
                    fidx.write(f"\n{idx}\n")
                    subword_len_counter = current_subwords_len
                else:
                    fout.write(f"{token}\t{label}\n")
                    fidx.write(f"{idx}\n")
                    subword_len_counter += current_subwords_len
        return 1

    tokenizer = TOKENIZERS[args.model_type].from_pretrained(
        args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)
    for lang in args.languages.split(','):
        out_dir = os.path.join(args.output_dir, lang)
        os.makedirs(out_dir, exist_ok=True)
        for file in ['dev', 'test', 'train']:
            infile = os.path.join(args.data_dir, f'{file}-{lang}.tsv')
            # NOTE(review): a model path containing '/' yields a nested output
            # path whose parent directory is never created -- confirm callers
            # only pass hub model names here.
            outfile = os.path.join(out_dir, "{}.{}".format(file, args.model_name_or_path))
            idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, args.model_name_or_path))
            if os.path.exists(outfile) and os.path.exists(idxfile):
                print(f'{outfile} and {idxfile} exist')
            elif _preprocess_one_file(infile, outfile, idxfile, tokenizer, args.max_len):
                print(f'finish preprocessing {outfile}')


def panx_preprocess(args):
    """Convert raw PAN-X files into ``token<TAB>label`` TSV files.

    For a fixed list of languages and the train/test/dev splits, read
    ``{args.data_dir}/{lang}-{split}`` and write
    ``{args.output_dir}/{split}-{lang}.tsv``. Missing input files are
    reported and skipped.
    """
    def _process_one_file(infile, outfile):
        """Convert one file; return 1 on success, 0 if ``infile`` is missing."""
        if not os.path.exists(infile):
            print(f'{infile} not exists')
            return 0
        with open(infile, 'r', encoding='utf-8') as fin:
            lines = fin.readlines()
        # Drop one trailing blank line so the output does not end with an empty
        # sentence. The ``lines`` guard keeps an empty file from raising IndexError.
        if lines and lines[-1].strip() == '':
            lines = lines[:-1]
        with open(outfile, 'w', encoding='utf-8') as fout:
            for line in lines:
                items = line.strip().split('\t')
                if len(items) != 2:
                    # Anything that is not a token/label pair is a sentence break.
                    fout.write('\n')
                    continue
                label = items[1].strip()
                # Tokens are expected to carry a prefix ending in ':' (presumably
                # the language code); keep only what follows the first ':'.
                # Tokens without a ':' are dropped.
                colon = items[0].find(':')
                if colon != -1:
                    token = items[0][colon + 1:].strip()
                    fout.write(f'{token}\t{label}\n')
        return 1

    os.makedirs(args.output_dir, exist_ok=True)
    langs = 'ar he vi id jv ms tl eu ml ta te af nl en de el bn hi mr ur fa fr it pt es bg ru ja ka ko th sw yo my zh kk tr et fi hu'.split(' ')
    for lg in langs:
        for split in ['train', 'test', 'dev']:
            infile = os.path.join(args.data_dir, f'{lg}-{split}')
            outfile = os.path.join(args.output_dir, f'{split}-{lg}.tsv')
            _process_one_file(infile, outfile)

def udpos_tokenize_preprocess(args):
    """Split UD-POS sentences so every chunk fits within ``args.max_len`` subwords.

    Only languages in ``args.languages`` (comma separated) that have all three
    of ``{args.data_dir}/{train,dev,test}-{lang}.tsv`` are processed. For each
    split, write two files under ``{args.output_dir}/{lang}/``:

    * ``{split}.{args.model_name_or_path}``: ``token<TAB>tag`` lines, with a
      blank line inserted whenever the running subword count would exceed the
      budget. Tokens without a tag get ``X``.
    * ``{split}.{args.model_name_or_path}.idx``: the source-sentence index of
      every token.

    Splits whose two output files both already exist are skipped.
    """
    def _preprocess_one_file(infile, outfile, idxfile, tokenizer, max_len):
        """Process one split; return 1 on success, 0 if ``infile`` is missing."""
        if not os.path.exists(infile):
            print(f'{infile} does not exist')
            return 0
        subword_len_counter = idx = 0
        # Leave room for special tokens (3 for XLM-R, 2 otherwise, per this script).
        special_tokens_count = 3 if isinstance(tokenizer, XLMRobertaTokenizer) else 2
        max_seq_len = max_len - special_tokens_count
        with open(infile, "rt", encoding="utf-8") as fin, \
                open(outfile, "w", encoding="utf-8") as fout, \
                open(idxfile, "w", encoding="utf-8") as fidx:
            for line in fin:
                line = line.strip()
                if not line:
                    # Sentence boundary: copy it through and advance the sentence index.
                    fout.write('\n')
                    fidx.write('\n')
                    idx += 1
                    subword_len_counter = 0
                    continue

                items = line.split()
                label = items[1].strip() if len(items) == 2 else "X"
                token = items[0].strip()  # never empty: ``line`` is non-blank
                current_subwords_len = len(tokenizer.tokenize(token))

                # Tokens the tokenizer drops, or that can never fit, become <unk>.
                if current_subwords_len == 0 or current_subwords_len > max_seq_len:
                    token = tokenizer.unk_token
                    current_subwords_len = 1

                if subword_len_counter + current_subwords_len > max_seq_len:
                    # Start a new chunk; the idx file keeps the source sentence index.
                    fout.write(f"\n{token}\t{label}\n")
                    fidx.write(f"\n{idx}\n")
                    subword_len_counter = current_subwords_len
                else:
                    fout.write(f"{token}\t{label}\n")
                    fidx.write(f"{idx}\n")
                    subword_len_counter += current_subwords_len
        return 1

    tokenizer = TOKENIZERS[args.model_type].from_pretrained(
        args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None)
    files = ['dev', 'test', 'train']
    for lang in args.languages.split(','):
        # Skip languages that lack any of the three splits.
        if not all(os.path.exists(os.path.join(args.data_dir, "{}-{}.tsv".format(split, lang)))
                   for split in ('train', 'dev', 'test')):
            continue
        out_dir = os.path.join(args.output_dir, lang)
        os.makedirs(out_dir, exist_ok=True)
        for file in files:
            infile = os.path.join(args.data_dir, "{}-{}.tsv".format(file, lang))
            outfile = os.path.join(out_dir, "{}.{}".format(file, args.model_name_or_path))
            idxfile = os.path.join(out_dir, "{}.{}.idx".format(file, args.model_name_or_path))
            if os.path.exists(outfile) and os.path.exists(idxfile):
                print(f'{outfile} and {idxfile} exist')
            elif _preprocess_one_file(infile, outfile, idxfile, tokenizer, args.max_len):
                print(f'finish preprocessing {outfile}')

def udpos_preprocess(args):
    """Convert UD treebank ``*conll`` files into per-language ``word<TAB>tag`` TSVs.

    Walk ``args.data_dir``. Every sub-directory whose name is one of the
    language codes below contributes its files ending in ``conll``. Each file
    goes to the train/dev/test split according to whether its name contains
    ``train``, ``dev`` or ``test``. Each non-empty split is written to
    ``{args.output_dir}/{split}-{lang}.tsv``: one ``word<TAB>tag`` line per
    token, with a blank line between sentences.
    """
    def _keep(words):
        """Return True for non-empty sentences that are not (almost) all ``_`` placeholders."""
        num_empty = sum(w == '_' for w in words)
        return bool(words) and (num_empty == 0 or num_empty < len(words) - 1)

    def _read_one_file(file):
        """Parse one 10-column CoNLL file into ``(words, tags, raw_lines)`` tuples.

        Words come from the 2nd column and tags from the 4th (assumed to be
        FORM and UPOS of the CoNLL-U layout). Any line without exactly 10
        tab-separated fields closes the current sentence, and so does the end
        of the file.
        """
        data = []
        sent, tag, lines = [], [], []
        with open(file, 'r', encoding='utf-8') as fin:
            for line in fin:
                items = line.strip().split('\t')
                if len(items) != 10:
                    if _keep(sent):
                        data.append((sent, tag, lines))
                    sent, tag, lines = [], [], []
                else:
                    sent.append(items[1].strip())
                    tag.append(items[3].strip())
                    lines.append(line.strip())
                    # Token ids must count up from 1 within each sentence.
                    assert len(sent) == int(items[0]), 'line={}, sent={}, tag={}'.format(line, sent, tag)
        # Flush a final sentence that is not followed by a separator line.
        if _keep(sent):
            data.append((sent, tag, lines))
        return data

    def remove_empty_space(data):
        """Strip spaces and zero-width non-joiners (U+200C) from every word."""
        new_data = {}
        for split in data:
            new_data[split] = []
            for sent, tag, lines in data[split]:
                new_sent = [''.join(w.replace('\u200c', '').split(' ')) for w in sent]
                lines = [line.replace('\u200c', '') for line in lines]
                # Words no longer contain spaces, so the joined sentence
                # re-splits into exactly one piece per tag.
                assert len(" ".join(new_sent).split(' ')) == len(tag)
                new_data[split].append((new_sent, tag, lines))
        return new_data

    def check_file(file):
        """Assert every ``words<TAB>tags`` line has as many words as tags."""
        with open(file, encoding='utf-8') as fin:
            for i, l in enumerate(fin):
                items = l.strip().split('\t')
                assert len(items[0].split(' ')) == len(items[1].split(' ')), 'idx={}, line={}'.format(i, l)

    def _write_files(data, output_dir, lang, suffix):
        """Write every non-empty split of ``data`` in the ``suffix`` format (mt, tsv or conll)."""
        for split in data:
            if len(data[split]) > 0:
                prefix = os.path.join(output_dir, f'{split}-{lang}')
                if suffix == 'mt':
                    # One sentence per line: "words<TAB>tags" (words only for test).
                    with open(prefix + '.mt.tsv', 'w', encoding='utf-8') as fout:
                        for idx, (sent, tag, _) in enumerate(data[split]):
                            newline = '\n' if idx != len(data[split]) - 1 else ''
                            if split == 'test':
                                fout.write('{}{}'.format(' '.join(sent), newline))
                            else:
                                fout.write('{}\t{}{}'.format(' '.join(sent), ' '.join(tag), newline))
                    check_file(prefix + '.mt.tsv')
                    print('  - finish checking ' + prefix + '.mt.tsv')
                elif suffix == 'tsv':
                    # One token per line, blank line after each sentence; the very
                    # last token has no newline of its own, so the file ends in one '\n'.
                    with open(prefix + '.tsv', 'w', encoding='utf-8') as fout:
                        for sidx, (sent, tag, _) in enumerate(data[split]):
                            for widx, (w, t) in enumerate(zip(sent, tag)):
                                newline = '' if (sidx == len(data[split]) - 1) and (widx == len(sent) - 1) else '\n'
                                fout.write('{}\t{}{}'.format(w, t, newline))
                            fout.write('\n')
                elif suffix == 'conll':
                    # Original CoNLL rows, blank line after each sentence.
                    with open(prefix + '.conll', 'w', encoding='utf-8') as fout:
                        for _, _, lines in data[split]:
                            for l in lines:
                                fout.write(l.strip() + '\n')
                            fout.write('\n')
                print(f'finish writing file to {prefix}.{suffix}')

    os.makedirs(args.output_dir, exist_ok=True)

    languages = 'af ar bg de el en es et eu fa fi fr he hi hu id it ja kk ko mr nl pt ru ta te th tl tr ur vi yo zh'.split(' ')
    for root, _dirs, files in os.walk(args.data_dir):
        lg = os.path.basename(root.strip())
        if root == args.data_dir or lg not in languages:
            continue

        data = {k: [] for k in ['train', 'dev', 'test']}
        for f in sorted(files):
            if f.endswith('conll'):
                file = os.path.join(root, f)
                examples = _read_one_file(file)
                if 'train' in f:
                    data['train'].extend(examples)
                elif 'dev' in f:
                    data['dev'].extend(examples)
                elif 'test' in f:
                    data['test'].extend(examples)
                else:
                    print('split not found: ', file)
                print(' - finish reading {}, {}'.format(file, [(k, len(v)) for k, v in data.items()]))

        data = remove_empty_space(data)
        for sub in ['tsv']:
            _write_files(data, args.output_dir, lg, sub)

def pawsx_preprocess(args):
    """Convert PAWS-X files into headerless ``sentence1<TAB>sentence2<TAB>label`` TSVs.

    For each language, read ``{args.data_dir}/{lang}/{name}.tsv``, skip its
    header row, and write ``{args.output_dir}/{split}-{lang}.tsv``. ``name``
    is one of:

    * ``train`` for English training data;
    * ``translated_train`` for other languages' training data;
    * ``test_2k`` or ``dev_2k`` for the evaluation splits.

    Labels are kept for every split, including test.
    """
    def _source_name(split, lang):
        """Return the base name of the PAWS-X file that holds ``split`` for ``lang``."""
        if split == 'train':
            return 'train' if lang == 'en' else 'translated_train'
        return {'test': 'test_2k', 'dev': 'dev_2k'}[split]

    def _preprocess_one_file(infile, outfile):
        """Copy columns 2-4 (sentence1, sentence2, label) of ``infile`` to ``outfile``."""
        rows = []
        with open(infile, 'r', encoding='utf-8') as fin:
            next(fin, None)  # header row
            for line in fin:
                items = line.strip().split('\t')
                rows.append([items[1].strip(), items[2].strip(), items[3]])
        # newline='' lets the csv module own line endings (no "\r\r\n" on Windows).
        with open(outfile, 'w', encoding='utf-8', newline='') as fout:
            csv.writer(fout, delimiter='\t').writerows(rows)

    os.makedirs(args.output_dir, exist_ok=True)

    for lang in ['en', 'de', 'es', 'fr', 'ja', 'ko', 'zh']:
        for split in ['train', 'test', 'dev']:
            infile = os.path.join(args.data_dir, lang, "{}.tsv".format(_source_name(split, lang)))
            outfile = os.path.join(args.output_dir, "{}-{}.tsv".format(split, lang))
            _preprocess_one_file(infile, outfile)
            print(f'finish preprocessing {outfile}')

def xnli_preprocess(args):
    """Convert XNLI and XNLI-MT files into headerless ``sentence1<TAB>sentence2<TAB>label`` TSVs.

    * Training data: ``{args.data_dir}/XNLI-MT-1.0/multinli/multinli.train.{lg}.tsv``
      (columns sentence1, sentence2, label) is written to
      ``{args.output_dir}/train-{lg}.tsv``.
    * Evaluation data: ``{args.data_dir}/XNLI-1.0/xnli.{split}.tsv`` for the
      test and dev splits is grouped by the language code in column 1. The
      sentences come from columns 7 and 8 and the label from column 2. Each
      group is written to ``{args.output_dir}/{split}-{lang}.tsv``.

    The header row of every input is skipped, and the label ``contradictory``
    is normalized to ``contradiction``.
    """
    def _read_rows(infile):
        """Yield the tab-separated fields of each line of ``infile`` after the header."""
        with open(infile, 'r', encoding='utf-8') as fin:
            next(fin, None)  # header row
            for line in fin:
                yield line.strip().split('\t')

    def _normalize_label(label):
        """Strip ``label`` and map ``contradictory`` to ``contradiction``."""
        label = label.strip()
        return "contradiction" if label == "contradictory" else label

    def _write_tsv(outfile, rows):
        """Write ``rows`` to ``outfile`` as tab-separated CSV, then report it."""
        # newline='' lets the csv module own line endings (no "\r\r\n" on Windows).
        with open(outfile, 'w', encoding='utf-8', newline='') as fout:
            csv.writer(fout, delimiter='\t').writerows(rows)
        print(f'finish preprocess {outfile}')

    os.makedirs(args.output_dir, exist_ok=True)

    for lg in ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh']:
        infile = os.path.join(args.data_dir, 'XNLI-MT-1.0/multinli/multinli.train.{}.tsv'.format(lg))
        outfile = os.path.join(args.output_dir, 'train-{}.tsv'.format(lg))
        _write_tsv(outfile, ([items[0].strip(), items[1].strip(), _normalize_label(items[2])]
                             for items in _read_rows(infile)))

    for split in ['test', 'dev']:
        infile = os.path.join(args.data_dir, 'XNLI-1.0/xnli.{}.tsv'.format(split))
        print(f'reading file {infile}')
        all_langs = defaultdict(list)
        for items in _read_rows(infile):
            all_langs[items[0].strip()].append(
                [items[6].strip(), items[7].strip(), _normalize_label(items[1])])
        print(f'# langs={len(all_langs)}')
        for lang, rows in all_langs.items():
            _write_tsv(os.path.join(args.output_dir, '{}-{}.tsv'.format(split, lang)), rows)


def tatoeba_preprocess(args):
    """Prepare Tatoeba retrieval data: copy sources and write sorted English targets.

    For every non-English language in ``lang3_dict``:

    * ``{args.data_dir}/tatoeba.{xxx}-eng.{xxx}`` is copied unchanged to
      ``{args.output_dir}/{xx}-en.{xx}``;
    * the stripped lines of ``tatoeba.{xxx}-eng.eng`` are written to
      ``{xx}-en.en`` in lexicographic order.

    Sorting breaks the line-by-line alignment between the two output files.
    NOTE(review): presumably deliberate, to hide the gold pairing (like the
    QA test-annotation removal) -- confirm before changing.
    """
    lang3_dict = {
        'afr':'af', 'ara':'ar', 'bul':'bg', 'ben':'bn',
        'deu':'de', 'ell':'el', 'spa':'es', 'est':'et',
        'eus':'eu', 'pes':'fa', 'fin':'fi', 'fra':'fr',
        'heb':'he', 'hin':'hi', 'hun':'hu', 'ind':'id',
        'ita':'it', 'jpn':'ja', 'jav':'jv', 'kat':'ka',
        'kaz':'kk', 'kor':'ko', 'mal':'ml', 'mar':'mr',
        'nld':'nl', 'por':'pt', 'rus':'ru', 'swh':'sw',
        'tam':'ta', 'tel':'te', 'tha':'th', 'tgl':'tl',
        'tur':'tr', 'urd':'ur', 'vie':'vi', 'cmn':'zh',
        'eng':'en',
    }
    os.makedirs(args.output_dir, exist_ok=True)
    for sl3, sl2 in lang3_dict.items():
        if sl3 == 'eng':
            continue
        src_file = f'{args.data_dir}/tatoeba.{sl3}-eng.{sl3}'
        tgt_file = f'{args.data_dir}/tatoeba.{sl3}-eng.eng'
        src_out = f'{args.output_dir}/{sl2}-en.{sl2}'
        tgt_out = f'{args.output_dir}/{sl2}-en.en'
        shutil.copy(src_file, src_out)
        with open(tgt_file, encoding='utf-8') as fin:
            tgts = [line.strip() for line in fin]
        with open(tgt_out, 'w', encoding='utf-8') as ftgt:
            ftgt.writelines(f'{t}\n' for t in sorted(tgts))


def xquad_preprocess(args):
    """Blank out the answer texts of every XQuAD file in ``args.data_dir``.

    Removing the test annotations guards against accidental cheating.
    """
    annotated_dir = args.data_dir
    remove_qa_test_annotations(annotated_dir)


def mlqa_preprocess(args):
    """Blank out the answer texts of every MLQA file in ``args.data_dir``.

    Removing the test annotations guards against accidental cheating.
    """
    annotated_dir = args.data_dir
    remove_qa_test_annotations(annotated_dir)


def tydiqa_preprocess(args):
    """Prepare TyDi QA GoldP: split train by language and anonymize the dev sets.

    1. Split ``{args.data_dir}/tydiqa-goldp-v1.1-train.json`` into
       ``{args.output_dir}/tydiqa.{iso}.train.json``. The language is taken
       from the question-id prefix. Every question becomes its own
       single-paragraph document, and its id keeps only the part after the
       last '-'.
    2. Rename ``tydiqa-goldp-v1.1-dev/tydiqa-goldp-dev-{language}.json`` to
       ``tydiqa.{iso}.dev.json`` in place. Files already renamed by a previous
       run are left alone, so the step can be re-run.
    3. Blank out the answer texts of every file in the dev directory.
    """
    LANG2ISO = {'arabic': 'ar', 'bengali': 'bn', 'english': 'en', 'finnish': 'fi',
                'indonesian': 'id', 'korean': 'ko', 'russian': 'ru',
                'swahili': 'sw', 'telugu': 'te'}
    assert os.path.exists(args.data_dir)
    train_file = os.path.join(args.data_dir, 'tydiqa-goldp-v1.1-train.json')
    os.makedirs(args.output_dir, exist_ok=True)

    # Split the training file into language-specific files.
    lang2data = defaultdict(list)
    with open(train_file, 'r', encoding='utf-8') as f_in:
        data = json.load(f_in)
    version = data['version']
    for doc in data['data']:
        for par in doc['paragraphs']:
            context = par['context']
            for qa in par['qas']:
                question_id = qa['id']
                example_lang = question_id.split('-')[0]
                q_id = question_id.split('-')[-1]
                # Sanity-check that every gold answer span matches the context.
                for answer in qa['answers']:
                    a_start, a_text = answer['answer_start'], answer['text']
                    assert context[a_start:a_start + len(a_text)] == a_text
                lang2data[example_lang].append({'paragraphs': [{
                    'context': context,
                    'qas': [{'answers': qa['answers'],
                             'question': qa['question'],
                             'id': q_id}]}]})

    for lang, docs in lang2data.items():
        out_file = os.path.join(args.output_dir, 'tydiqa.%s.train.json' % LANG2ISO[lang])
        with open(out_file, 'w', encoding='utf-8') as f:
            json.dump({'data': docs, 'version': version}, f)

    # Rename the dev files (idempotently, so a re-run does not crash).
    dev_dir = os.path.join(args.data_dir, 'tydiqa-goldp-v1.1-dev')
    assert os.path.exists(dev_dir)
    for lang, iso in LANG2ISO.items():
        src_file = os.path.join(dev_dir, 'tydiqa-goldp-dev-%s.json' % lang)
        dst_file = os.path.join(dev_dir, 'tydiqa.%s.dev.json' % iso)
        if not os.path.exists(src_file) and os.path.exists(dst_file):
            continue
        os.rename(src_file, dst_file)

    # Remove the test annotations to prevent accidental cheating.
    remove_qa_test_annotations(dev_dir)


def remove_qa_test_annotations(test_dir):
    """Blank out answer texts in every SQuAD-style JSON file directly in ``test_dir``.

    Each file is rewritten in place with the same ``version``. Every question
    becomes its own single-paragraph document whose only answer has
    ``text == ''``. The ``answer_start`` of the question's last gold answer is
    kept, or 0 when the question has no answers. Sub-directories are skipped.
    Gold answer spans are sanity-checked against the context first.

    NOTE(review): keeping ``answer_start`` still reveals the answer position;
    an earlier revision used 0 -- confirm which one downstream readers need.
    """
    assert os.path.exists(test_dir)
    for file_name in os.listdir(test_dir):
        test_file = os.path.join(test_dir, file_name)
        if not os.path.isfile(test_file):
            continue
        with open(test_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        new_data = []
        for doc in data['data']:
            for par in doc['paragraphs']:
                context = par['context']
                for qa in par['qas']:
                    # Reset per question so a question without answers does not
                    # inherit the previous question's offset (or hit a NameError).
                    a_start = 0
                    for answer in qa['answers']:
                        a_start, a_text = answer['answer_start'], answer['text']
                        assert context[a_start:a_start + len(a_text)] == a_text
                    new_data.append({'paragraphs': [{
                        'context': context,
                        'qas': [{'answers': [{'answer_start': a_start, 'text': ''}],
                                 'question': qa['question'],
                                 'id': qa['id']}]}]})
        with open(test_file, 'w', encoding='utf-8') as f:
            json.dump({'data': new_data, 'version': data['version']}, f)


if __name__ == "__main__":
    # --task value -> preprocessing entry point. An unknown task is rejected
    # by argparse (via ``choices``) instead of silently doing nothing.
    TASK_FUNCTIONS = {
        'panx_tokenize': panx_tokenize_preprocess,
        'panx': panx_preprocess,
        'udpos_tokenize': udpos_tokenize_preprocess,
        'udpos': udpos_preprocess,
        'pawsx': pawsx_preprocess,
        'xnli': xnli_preprocess,
        'tatoeba': tatoeba_preprocess,
        'xquad': xquad_preprocess,
        'mlqa': mlqa_preprocess,
        'tydiqa': tydiqa_preprocess,
    }

    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output data dir where any processed files will be written to.")
    parser.add_argument("--task", default="panx", type=str, required=True,
                        choices=sorted(TASK_FUNCTIONS),
                        help="The task name")
    parser.add_argument("--model_name_or_path", default="bert-base-multilingual-cased", type=str,
                        help="The pre-trained model")
    parser.add_argument("--model_type", default="bert", type=str,
                        help="model type")
    parser.add_argument("--max_len", default=512, type=int,
                        help="the maximum length of sentences")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="whether to do lower case")
    parser.add_argument("--cache_dir", default=None, type=str,
                        help="cache directory")
    parser.add_argument("--languages", default="en", type=str,
                        help="process language")
    # NOTE: the next two flags are accepted for CLI compatibility but are not
    # read by any preprocessing function in this file.
    parser.add_argument("--remove_last_token", action='store_true',
                        help="whether to remove the last token")
    parser.add_argument("--remove_test_label", action='store_true',
                        help="whether to remove test set label")
    args = parser.parse_args()

    TASK_FUNCTIONS[args.task](args)
